{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import numba\n",
    "\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Optional, List, Callable\n",
    "\n",
    "import torch.distributed as dist\n",
    "\n",
    "import numpy as np\n",
    "import numba\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def ffd_check(a: np.ndarray, c: int, n: int):\n",
    "    # First-fit-decreasing bin packing (feasibility only)\n",
    "    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing\n",
    "    #\n",
    "    # Returns True when every size in a[] can be placed into one of n bins of\n",
    "    # capacity c, visiting sizes from largest to smallest and always choosing\n",
    "    # the lowest-numbered bin that still has room; False as soon as one size\n",
    "    # fits nowhere.\n",
    "\n",
    "    descending = np.sort(a)[::-1]\n",
    "    free = np.full((n, ), c, dtype=a.dtype)\n",
    "\n",
    "    for size in descending:\n",
    "        placed = False\n",
    "        for bin_id in range(n):\n",
    "            if size <= free[bin_id]:\n",
    "                free[bin_id] -= size\n",
    "                placed = True\n",
    "                break\n",
    "\n",
    "        if not placed:\n",
    "            return False\n",
    "\n",
    "    return True\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def ffd_with_result(a: np.ndarray, c: int, start_index: int):\n",
    "    # First-fit-decreasing bin packing (with result return)\n",
    "    #\n",
    "    # Sizes are visited from largest to smallest. Each one goes into the first\n",
    "    # open bin with enough room; if none has room a new bin of capacity c is\n",
    "    # opened for it. Returns one list per bin containing the positions of its\n",
    "    # items in a[], each shifted by start_index.\n",
    "\n",
    "    order = np.argsort(a)[::-1]\n",
    "    sizes = a[order]\n",
    "\n",
    "    remaining = []  # free space left in each open bin\n",
    "    members = []    # shifted item positions assigned to each open bin\n",
    "    for pos in range(len(sizes)):\n",
    "        size = sizes[pos]\n",
    "        item = order[pos] + start_index\n",
    "\n",
    "        placed = False\n",
    "        for bin_id in range(len(remaining)):\n",
    "            if size <= remaining[bin_id]:\n",
    "                remaining[bin_id] -= size\n",
    "                members[bin_id].append(item)\n",
    "                placed = True\n",
    "                break\n",
    "\n",
    "        if not placed:\n",
    "            remaining.append(c - size)\n",
    "            members.append([item])\n",
    "\n",
    "    return members\n",
    "\n",
    "\n",
    "@numba.njit\n",
    "def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):\n",
    "    # Dynamic batch allocator, similar to Multifit\n",
    "    # https://en.wikipedia.org/wiki/Multifit_algorithm\n",
    "    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)\n",
    "    #\n",
    "    # Walks through lengths[] from the front. For every batch it binary-searches\n",
    "    # a prefix length that ffd_check accepts for n bins of capacity c, packs that\n",
    "    # prefix with ffd_with_result and keeps the bin belonging to `rank`.\n",
    "    # Stops (without emitting) once a packing uses fewer than n bins.\n",
    "    #\n",
    "    # Returns: (bins of `rank` per batch, numseqs total over all bins per batch,\n",
    "    #           cumulative length consumed, number of batches * c * n)\n",
    "\n",
    "    s = 0\n",
    "    start_index = 0\n",
    "    result = []\n",
    "    result_totseqs = []\n",
    "\n",
    "    while True:\n",
    "        # binary search over the prefix length in [lo, hi); lo = 1 is taken as given\n",
    "        lo = 1\n",
    "        hi = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, \"right\")\n",
    "\n",
    "        while hi - lo > 1:\n",
    "            mid = (lo + hi) // 2\n",
    "            if ffd_check(lengths[start_index: start_index + mid], c, n):\n",
    "                lo = mid\n",
    "            else:\n",
    "                hi = mid\n",
    "\n",
    "        # pack the chosen prefix\n",
    "        packed = ffd_with_result(lengths[start_index: start_index + lo], c, start_index)\n",
    "        if len(packed) < n:\n",
    "            break\n",
    "\n",
    "        # bin of the local rank\n",
    "        result.append(packed[rank])\n",
    "        # sequence count summed over the bins of all ranks\n",
    "        seq_count = 0\n",
    "        for bin_items in packed:\n",
    "            for item in bin_items:\n",
    "                seq_count += numseqs[item]\n",
    "        result_totseqs.append(seq_count)\n",
    "\n",
    "        # advance past the samples that were just packed\n",
    "        start_index += lo\n",
    "        s = lengths_cumsum[start_index - 1]\n",
    "\n",
    "    return result, result_totseqs, s, len(result) * c * n\n",
    "\n",
    "\n",
    "class MultipackDistributedDataloader:\n",
    "    \"\"\"Unpadded data loading using Multipack.\n",
    "       Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.\n",
    "\n",
    "       Each epoch the samples are shuffled with seed ``seed + epoch`` and packed by ``allocate`` into\n",
    "       batches of ``num_replicas`` bins with capacity ``batch_max_length``. Iterating yields the bin\n",
    "       assigned to ``rank``, collated from ``dataset``, together with sequence-count totals.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dataset: Any,\n",
    "        lengths: np.ndarray,\n",
    "        numseqs: np.ndarray,\n",
    "\n",
    "        batch_max_length: int,\n",
    "        collate_fn: Callable,\n",
    "\n",
    "        num_replicas: Optional[int] = None,\n",
    "        rank: Optional[int] = None,\n",
    "\n",
    "        seed: int = 0,\n",
    "    ):\n",
    "        \"\"\"Store the dataset description and resolve the distributed topology.\n",
    "\n",
    "        Args:\n",
    "            dataset: Object indexed with an array of sample indices to fetch a batch.\n",
    "            lengths: Per-sample length; must be a numpy array.\n",
    "            numseqs: Per-sample sequence count, indexed in parallel with ``lengths``.\n",
    "            batch_max_length: Capacity of one bin (one rank's share of a batch).\n",
    "            collate_fn: Applied to ``dataset[batch]`` before yielding.\n",
    "            num_replicas: Number of bins per batch; defaults to the world size.\n",
    "            rank: Which bin of every batch this loader yields; defaults to the process rank.\n",
    "            seed: Base seed of the per-epoch shuffle.\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: If ``num_replicas`` or ``rank`` is omitted and the\n",
    "                distributed package is not available.\n",
    "        \"\"\"\n",
    "        # Dataset\n",
    "        self.dataset = dataset\n",
    "        self.lengths = lengths\n",
    "        self.numseqs = numseqs\n",
    "        assert isinstance(self.lengths, np.ndarray)\n",
    "\n",
    "        self.batch_max_length = batch_max_length\n",
    "        self.collate_fn = collate_fn\n",
    "\n",
    "        # A sample longer than one bin cannot be packed within the limit. With\n",
    "        # several replicas the allocator stops when it reaches such a sample and\n",
    "        # the rest of the epoch is dropped, so surface the problem up front.\n",
    "        if self.lengths.size > 0 and self.lengths.max() > batch_max_length:\n",
    "            import warnings\n",
    "\n",
    "            num_oversize = int((self.lengths > batch_max_length).sum())\n",
    "            warnings.warn(\n",
    "                f\"{num_oversize} sample(s) are longer than batch_max_length={batch_max_length}; \"\n",
    "                \"they cannot be packed within the limit and batches may be truncated or overflow\",\n",
    "                RuntimeWarning,\n",
    "                stacklevel=2,\n",
    "            )\n",
    "\n",
    "        # Get rank\n",
    "        if num_replicas is None:\n",
    "            if not dist.is_available():\n",
    "                raise RuntimeError(\"Requires distributed package to be available\")\n",
    "            num_replicas = dist.get_world_size()\n",
    "        if rank is None:\n",
    "            if not dist.is_available():\n",
    "                raise RuntimeError(\"Requires distributed package to be available\")\n",
    "            rank = dist.get_rank()\n",
    "\n",
    "        self.num_replicas = num_replicas\n",
    "        self.rank = rank\n",
    "\n",
    "        # Seed\n",
    "        self.seed = seed\n",
    "\n",
    "        # Epoch\n",
    "        self.epoch = 0\n",
    "\n",
    "        # statistics\n",
    "        self.eff_total_used = 0\n",
    "        self.eff_total_slots = 0\n",
    "\n",
    "    def set_epoch(self, epoch: int):\n",
    "        \"\"\"Select the epoch; it is added to ``seed`` to derive the shuffle order.\"\"\"\n",
    "        self.epoch = epoch\n",
    "\n",
    "    def generate_batches(self, set_stats=False):\n",
    "        \"\"\"Shuffle the samples for the current epoch and pack them into batches.\n",
    "\n",
    "        Args:\n",
    "            set_stats: When true, add this allocation's used length and total\n",
    "                slots to the running efficiency counters.\n",
    "\n",
    "        Returns:\n",
    "            ``(batches, totseqs, curseqs)`` with one entry per batch: the dataset\n",
    "            indices assigned to this rank, the ``numseqs`` total over all ranks,\n",
    "            and the ``numseqs`` total of this rank.\n",
    "        \"\"\"\n",
    "        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))\n",
    "\n",
    "        lengths        = self.lengths[indices]\n",
    "        numseqs        = self.numseqs[indices]\n",
    "        lengths_cumsum = np.cumsum(lengths)\n",
    "\n",
    "        batches, totseqs, total_used, total_slots = allocate(lengths=lengths,\n",
    "                                                             numseqs=numseqs,\n",
    "                                                             lengths_cumsum=lengths_cumsum,\n",
    "                                                             rank=self.rank,\n",
    "                                                             c=self.batch_max_length,\n",
    "                                                             n=self.num_replicas)\n",
    "\n",
    "        # allocate() works on positions in the shuffled order; map them back to dataset indices\n",
    "        curseqs = [np.sum(numseqs[batch]) for batch in batches]\n",
    "        batches = [indices[batch]         for batch in batches]\n",
    "\n",
    "        # statistics\n",
    "        if set_stats:\n",
    "            self.eff_total_used += total_used\n",
    "            self.eff_total_slots += total_slots\n",
    "\n",
    "        return batches, totseqs, curseqs\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Yield ``(collated batch, totseq, curseq)`` for every batch of the current epoch.\"\"\"\n",
    "        all_batches, all_totseqs, all_curseqs = self.generate_batches(set_stats=True)\n",
    "\n",
    "        for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs):\n",
    "            yield self.collate_fn(self.dataset[batch]), totseq, curseq\n",
    "\n",
    "    def num_batches(self):\n",
    "        \"\"\"Return how many batches the current epoch produces (re-runs the allocation).\"\"\"\n",
    "        batches, _, _ = self.generate_batches()\n",
    "        return len(batches)\n",
    "\n",
    "    def efficiency(self):\n",
    "        \"\"\"Return used length divided by allocated slots, accumulated over all iterations.\n",
    "\n",
    "        Returns 0.0 while no slots have been allocated (nothing iterated yet, or\n",
    "        the data never filled a single batch) instead of dividing by zero.\n",
    "        \"\"\"\n",
    "        if self.eff_total_slots == 0:\n",
    "            return 0.0\n",
    "        return self.eff_total_used / self.eff_total_slots\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3], [2, 0], [1, 4], [5]]\n"
     ]
    }
   ],
   "source": [
    "lengths = np.array([1, 5, 7, 8, 3, 2])\n",
    "lengths_cumsum = lengths.cumsum()\n",
    "\n",
    "# Pack the sizes into bins of capacity 8; positions are reported without an offset\n",
    "print(ffd_with_result(lengths, 8, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[29, 29, 29, 29, 29, 29, 29, 29]\n",
      "Efficiency: [0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559]\n",
      "Overall Efficiency: 0.9955281976408559\n"
     ]
    }
   ],
   "source": [
    "DATASET = \"../../dataset_processed/openchat.train.json\"\n",
    "C = 14 * 2048\n",
    "N = 8\n",
    "EPOCHS = 10\n",
    "NUMSEQS_SEED = 0\n",
    "\n",
    "# Load dataset\n",
    "with open(DATASET, \"r\") as f:\n",
    "    dataset = json.load(f)\n",
    "\n",
    "# Check allocator efficiency\n",
    "lengths = np.array([len(tokens) for tokens, masks, group in dataset])\n",
    "# Synthetic per-sample sequence counts in [1, 10); seeded so a re-run checks the same data\n",
    "numseqs = np.random.default_rng(NUMSEQS_SEED).integers(low=1, high=10, size=lengths.shape)\n",
    "# lengths = np.random.randint(0, 2048, (int(5e6))).astype(np.int32)\n",
    "\n",
    "# test sampler correctness & efficiency\n",
    "tot_len = 0\n",
    "tot_batches = 0\n",
    "\n",
    "# One loader per simulated rank; the dataset is the identity mapping so batches are sample indices\n",
    "dataloaders = [MultipackDistributedDataloader(dataset=np.arange(len(lengths)), lengths=lengths, numseqs=numseqs,\n",
    "                                              batch_max_length=C, \n",
    "                                              num_replicas=N, rank=rank,\n",
    "                                              collate_fn=lambda x: x) for rank in range(N)]\n",
    "print([loader.num_batches() for loader in dataloaders])\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    batches = []\n",
    "    totseqs = []\n",
    "    curseqs = []\n",
    "\n",
    "    for loader in dataloaders:\n",
    "        loader.set_epoch(epoch)\n",
    "        totseqs.append([])\n",
    "        curseqs.append([])\n",
    "\n",
    "        for batch, totseq, curseq in loader:\n",
    "            batches.extend(batch)\n",
    "\n",
    "            gt_curseq = np.sum(numseqs[batch])\n",
    "            # print (batch, curseq, gt_curseq)\n",
    "            assert gt_curseq == curseq\n",
    "\n",
    "            totseqs[-1].append(totseq)\n",
    "            curseqs[-1].append(gt_curseq)\n",
    "\n",
    "            # Check constraints\n",
    "            overall_len = lengths[batch].sum()\n",
    "            assert overall_len <= C\n",
    "\n",
    "            tot_len += overall_len\n",
    "            tot_batches += 1\n",
    "\n",
    "    # Check overall unique: no sample may be handed to two bins in one epoch.\n",
    "    # Compare sizes rather than list(set(...)), whose ordering is not guaranteed.\n",
    "    assert len(batches) == len(set(batches))  # Unique\n",
    "\n",
    "    # Check totseq accurate\n",
    "    gt_totseqs = np.sum(curseqs, axis=0)\n",
    "    for i in range(len(totseqs)):\n",
    "        assert (totseqs[i] == gt_totseqs).all()\n",
    "\n",
    "# Check efficiency\n",
    "efficiency = [loader.efficiency() for loader in dataloaders]\n",
    "print(f\"Efficiency: {efficiency}\")\n",
    "\n",
    "print(f\"Overall Efficiency: {tot_len / (tot_batches * C)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150.98552224214524"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Total capacity of one batch (C per rank, N ranks) relative to the mean sample length\n",
    "(C * N) / lengths.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "150.31034482758622"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean per-batch sequence total from the last epoch checked above\n",
    "gt_totseqs.mean()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
